{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN-循环神经网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\ProgramData\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.python import debug as tf_debug\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 超参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_steps = 10\n",
    "batch_size = 200\n",
    "num_classes = 2\n",
    "state_size = 16\n",
    "learning_rate = 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 生成数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "就是按照文章中提到的规则，这里生成1000000个\n",
    "'''\n",
    "def gen_data(size=1000000):\n",
    "    X = np.array(np.random.choice(2, size=(size,)))\n",
    "    Y = []\n",
    "    '''根据规则生成Y'''\n",
    "    for i in range(size):   \n",
    "        threshold = 0.5\n",
    "        if X[i-3] == 1:\n",
    "            threshold += 0.5\n",
    "        if X[i-8] == 1:\n",
    "            threshold -=0.25\n",
    "        if np.random.rand() > threshold:\n",
    "            Y.append(0)\n",
    "        else:\n",
    "            Y.append(1)\n",
    "    return X, np.array(Y)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 生成batch数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_batch(raw_data, batch_size, num_step):\n",
    "    raw_x, raw_y = raw_data\n",
    "    data_length = len(raw_x)\n",
    "    batch_patition_length = data_length // batch_size                         # ->5000\n",
    "    data_x = np.zeros([batch_size, batch_patition_length], dtype=np.int32)    # ->(200, 5000)\n",
    "    data_y = np.zeros([batch_size, batch_patition_length], dtype=np.int32)    # ->(200, 5000)\n",
    "    '''填到矩阵的对应位置'''\n",
    "    for i in range(batch_size):\n",
    "        data_x[i] = raw_x[batch_patition_length*i:batch_patition_length*(i+1)]# 每一行取batch_patition_length个数，即5000\n",
    "        data_y[i] = raw_y[batch_patition_length*i:batch_patition_length*(i+1)]\n",
    "    epoch_size = batch_patition_length // num_steps                           # ->5000/5=1000 就是每一轮的大小\n",
    "    for i in range(epoch_size):   # 抽取 epoch_size 个数据\n",
    "        x = data_x[:, i * num_steps:(i + 1) * num_steps]                      # ->(200, 5)\n",
    "        y = data_y[:, i * num_steps:(i + 1) * num_steps]\n",
    "        yield (x, y)    # yield 是生成器，生成器函数在生成值后会自动挂起并暂停他们的执行和状态（最后就是for循环结束后的结果，共有1000个(x, y)）\n",
    "def gen_epochs(n, num_steps):\n",
    "    for i in range(n):\n",
    "        yield gen_batch(gen_data(), batch_size, num_steps)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义placeholder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = tf.placeholder(tf.int32, [batch_size, num_steps], name=\"x\")\n",
    "y = tf.placeholder(tf.int32, [batch_size, num_steps], name='y')\n",
    "init_state = tf.zeros([batch_size, state_size])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RNN输入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_one_hot = tf.one_hot(x, num_classes)\n",
    "rnn_inputs = tf.unstack(x_one_hot, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义RNN cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.variable_scope('rnn_cell'):\n",
    "    W = tf.get_variable('W', [num_classes + state_size, state_size])\n",
    "    b = tf.get_variable('b', [state_size], initializer=tf.constant_initializer(0.0))\n",
    "    \n",
    "def rnn_cell(rnn_input, state):\n",
    "    with tf.variable_scope('rnn_cell', reuse=True):\n",
    "        W = tf.get_variable('W', [num_classes+state_size, state_size])\n",
    "        b = tf.get_variable('b', [state_size], initializer=tf.constant_initializer(0.0))\n",
    "    return tf.tanh(tf.matmul(tf.concat([rnn_input, state],1),W) + b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 将rnn cell添加到计算图中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = init_state\n",
    "rnn_outputs = []\n",
    "for rnn_input in rnn_inputs:\n",
    "    state = rnn_cell(rnn_input, state)  # state会重复使用，循环\n",
    "    rnn_outputs.append(state)\n",
    "final_state = rnn_outputs[-1]        # 得到最后的state\n",
    "\n",
    "#cell = tf.contrib.rnn.BasicRNNCell(num_units=state_size)\n",
    "#rnn_outputs, final_state = tf.contrib.rnn.static_rnn(cell=cell, inputs=rnn_inputs, \n",
    "                                                    #initial_state=init_state)\n",
    "#rnn_outputs, final_state = tf.nn.dynamic_rnn(cell=cell, inputs=rnn_inputs, \n",
    "#initial_state=init_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测，损失，优化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.variable_scope('softmax'):\n",
    "    W = tf.get_variable('W', [state_size, num_classes])\n",
    "    b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))\n",
    "logits = [tf.matmul(rnn_output, W) + b for rnn_output in rnn_outputs]\n",
    "predictions = [tf.nn.softmax(logit) for logit in logits]\n",
    "\n",
    "y_as_list = tf.unstack(y, num=num_steps, axis=1)\n",
    "losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label,logits=logit) for logit, label in zip(logits, y_as_list)]\n",
    "total_loss = tf.reduce_mean(losses)\n",
    "train_step = tf.train.AdagradOptimizer(learning_rate).minimize(total_loss)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "epoch 0\n",
      "第 100 步的平均损失 0.5057216426730156\n",
      "第 200 步的平均损失 0.4789837658405304\n",
      "第 300 步的平均损失 0.4732507354021072\n",
      "第 400 步的平均损失 0.46972491800785066\n",
      "\n",
      "epoch 1\n",
      "第 100 步的平均损失 0.4703376641869545\n",
      "第 200 步的平均损失 0.4640725874900818\n",
      "第 300 步的平均损失 0.46390075027942657\n",
      "第 400 步的平均损失 0.4624347412586212\n",
      "\n",
      "epoch 2\n",
      "第 100 步的平均损失 0.46826018780469897\n",
      "第 200 步的平均损失 0.4612054255604744\n",
      "第 300 步的平均损失 0.460566953420639\n",
      "第 400 步的平均损失 0.4593639954924583\n",
      "\n",
      "epoch 3\n",
      "第 100 步的平均损失 0.4652071687579155\n",
      "第 200 步的平均损失 0.4578261461853981\n",
      "第 300 步的平均损失 0.4586115163564682\n",
      "第 400 步的平均损失 0.45804080605506897\n",
      "\n",
      "epoch 4\n",
      "第 100 步的平均损失 0.4658902403712273\n",
      "第 200 步的平均损失 0.4588862270116806\n",
      "第 300 步的平均损失 0.45785953253507616\n",
      "第 400 步的平均损失 0.45685730397701263\n",
      "0.5057216426730156\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD8CAYAAACb4nSYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xl8VNX9//HXJxtZScjCHkjCjrIIEcF9QUVRsO7aRavWWuvDarUu1Vq/2tZqtdZarVWr3exPwWpLFUXBfQEJFmSHEIIJYEJYEkgI2c7vj5lgCAkZsk0y9/18PPLIzJ1zZz65DO+5c+6555pzDhER8YawYBcgIiKdR6EvIuIhCn0REQ9R6IuIeIhCX0TEQxT6IiIeotAXEfEQhb6IiIco9EVEPCQi2AU0lpqa6jIyMoJdhohIt7JkyZIS51xaS+26XOhnZGSQk5MT7DJERLoVM9sUSDt174iIeIhCX0TEQxT6IiIeotAXEfEQhb6IiIco9EVEPEShLyLiISET+pt37eWRt9by5faKYJciItJlhUzol+2t5vF3cllWuCvYpYiIdFkhE/oZKXEAbCwpD3IlIiJdV8iEfkxUOP0To8lX6IuINCtkQh8gMy2OPIW+iEizQiv0U+PI27YH51ywSxER6ZJCLPTjKausYWdFdbBLERHpkkIq9LNS6w/m7glyJSIiXVNIhX6GP/TztqlfX0SkKSEV+gN7xRARZhq2KSLSjJAK/cjwMAYlxyr0RUSaEVKhD74RPAp9EZGmhWzo19Vp2KaISGOhF/ppceyrqWNrWWWwSxER6XJCL/T9I3g0HYOIyMFCLvSzUuMBNB2DiEgTQi70+/TsQUxkOBs1Vl9E5CAhF/pm5j+Yq7NyRUQaC7nQB9/BXA3bFBE5WEiGflZqHAU791JVUxfsUkREupSQDP2MlDhq6xwFO3W9XBGRhgIKfTObZmZrzSzXzO5o4vErzWybmS31/1zT4LErzGy9/+eK9iy+OZlp/tk2dTBXROQAES01MLNw4AngdKAQWGxmc5xzqxo1fck5d0OjdZOBnwPZgAOW+Nfd2S7VN+PrKZYV+iIiDQWypz8JyHXO5TnnqoAXgZkBPv+ZwNvOuR3+oH8bmNa6UgOXFBtFr9hIjdUXEWkkkNAfABQ0uF/oX9bYBWb2hZm9bGbph7OumV1rZjlmlrNt27YASz80DdsUETlYIKFvTSxrPJvZf4EM59xYYD7w18NYF+fc0865bOdcdlpaWgAltSwzNZ78Eh3IFRFpKJDQLwTSG9wfCGxp2MA5t905t89/9xlgYqDrdpSstDi+KqukfF9NZ7yciEi3EEjoLwaGmVmmmUUBlwJzGjYws34N7s4AVvtvzwPOMLNeZtYLOMO/rMPtn3htu/r1RUTqtTh6xzlXY2Y34AvrcOA559xKM7sPyHHOzQFuNLMZQA2wA7jSv+4OM7sf3wcHwH3OuR0d8HccJLPBCJ4j+id2xkuKiHR5LYY+gHNuLjC30bJ7Gty+E7izmXWfA55rQ42tkpGisfoiIo2F5Bm5ADFR4fRPjNZYfRGRBkI29MF3Zq7G6ouIfC2kQz8jJY68bXtwTtfLFRGBEA/9zNQ4yipr2FlRHexSRES6hJAO/az6idd0Zq6ICBDioZ9Zf71cjeAREQFCPPQH9oohIsx0gpaIiF9Ih35keBiDkmM1bFNExC+kQx98B3PVvSMi4uOJ0M/fXk5dnYZtioiEfuinxVFZXcdXZZXBLkVEJOhCP/R16UQRkf1CPvSz6odtKvRFREI/9Pv07EFMZLhm2xQRwQOhb2Zk6Hq5IiKAB0IfICs1Tn36IiJ4JPQzU+Mo2LmX6tq6YJciIhJUngn92jpHwY6KYJciIhJU3gj9NA3bFBEBj4R+lsbqi4gAHgn9pNgoesVGaqy+iHieJ0IffP36GqsvIl7nodCPV/eOiHieZ0I/Ky2Or8oqKd9XE+xSRESCxjOhXz/xmq6iJSJe5pnQz0jRCB4REe+EfmosgA7mioineSb0Y6Mi6JcYzUZ174iIh3km9ME/bFPdOyLiYQp9EREP8Vzo76qoZmd5VbBLEREJioBC38ymmdlaM8s1szsO0e5CM3Nmlu2/H2Vmz5vZcjNbZmYnt1PdrZLln3hN0zGIiFe1GPpmFg48AZwFjAYuM7PRTbRLAG4EFjVY/D0A59wY4HTgETML2reLTP/1ctXFIyJeFUgATwJynXN5zrkq4EVgZhPt7gceAiobLBsNLABwzhUDu4DsNlXcBgN7xRARZrp0ooh4ViChPwAoaHC/0L9sPzM7Ckh3zr3WaN1lwEwzizCzTGAikN6GetskMjyMQcmx2tMXEc+KCKCNNbHM7X/Q113zKHBlE+2eA0YBOcAm4BPgoMlvzOxa4FqAQYMGBVBS62WmxpGnE7RExKMC2dMv5MC984HAlgb3E4AjgffMLB+YDMwxs2znXI1z7mbn3Hjn3EwgCVjf+AWcc08757Kdc9lpaWmt/VsCkpkaR/72curqXMuNRURCTCChvxgYZmaZZhYFXArMqX/QOVfqnEt1zmU45zKAhcAM51yOmcWaWRyAmZ0O1DjnVrX/nxG4jNQ4Kqvr+KqssuXGIiIhpsXuHedcjZndAMwDwoHnnHMrzew+IMc5N+cQq/cG5plZHbAZ+HZ7FN0W9ZdOzC8pp39STJCrERHpXIH06eOcmwvMbbTsnmbantzgdj4wovXltb/MBmP1jx2aGuRqREQ6l6fOyAXokxBNTGS4RvCIiCd5LvTDwowMzcEjIh7ludAHX7++Ql9EvMiToZ+ZGseXOyqorq0LdikiIp3Ks6FfW+co2FER7FJERDqVN0M/TdfLFRFv8mTo14/VV+iLiNd4MvSTYqPoFRupefVFxHM8Gfrgv3SiJl4TEY/xbOhn+CdeExHxEs+GflZqHFtLK6moOmimZxGRkOXZ0K+/dGJ+iYZtioh3eDj0NYJHRLzHs6GfkRoLoOvlioineDb0Y6Mi6JcYrWGbIuIpng198A/bVOiLiIco9BX6IuIhng/9XRXV7CyvCnYpIiKdwtOhn9Xg0okiIl7g6dCvH6uvLh4R8QpPh/7AXjGEhxn5Cn0R8QhPh35keBiDkmO1py8inuHp0AffwVz16YuIVyj0U+PILymnrs4FuxQRkQ6n0E+NY291LUW7K4NdiohIh/N86O+/dKIuqCIiHuD50M/UWH0R8RDPh36fhGhiIsM1gkdEPMHzoR8WZmRoDh4R8QjPhz74+vUV+iLiBQp9fCN4vtxRQXVtXbBLERHpUAp9fKFfW+co3Lk32KWIiHSogELfzKaZ2VozyzWzOw7R7kIzc2aW7b8faWZ/NbPlZrbazO5sr8LbU8b+6+Xq0okiEtpaDH0zCweeAM4CRgOXmdnoJtolADcCixosvgjo4ZwbA0wEvm9mGW0vu33Vj9XP01h9EQlxgezpTwJynXN5zrkq4EVgZhPt7gceAhqe2uqAODOLAGKAKqCsbSW3v15xUSTFRupgroiEvEBCfwBQ0OB+oX/ZfmZ2FJDunHut0bovA+XAVuBL4GHn3I7GL2Bm15pZjpnlbNu27XDqbze6dKKIeEEgoW9NLNs/O5mZhQGPArc00W4SUAv0BzKBW8ws66Anc+5p51y2cy47LS0toMLbm0JfRLwgkNAvBNIb3B8IbGlwPwE4EnjPzPKBycAc/8Hcy4E3nXPVzrli4GMguz0Kb29ZqXFsLa2koqom2KWIiHSYQEJ/MTDMzDLNLAq4FJhT/6BzrtQ5l+qcy3DOZQALgRnOuRx8XTqnmk8cvg+ENe3+V7SD+ksn5pdUBLkSEZGO02LoO+dqgBuAecBqYJZzbqWZ3WdmM1pY/QkgHliB78PjeefcF22suUOM7JcAwL+Xbg5yJSIiHScikEbOubnA3EbL7mmm7ckNbu/BN2yzyxuSFs9lkwbxzId5TB3Vh0mZycEuSUSk3emM3Abunj6KQcmx3DJ7KXv2qW9fREKPQr+BuB4R/PbicWzeuZf7/7sq2OWIiLQ7hX4jEwcnc91JQ3gpp4C3VxUFuxwRkXal0G/CTVOHM6pfT+585Qu279kX7HJERNqNQr8JURFh/O6S8ZTtreGnry7HOdfySiIi3YBCvxkj+iZw65nDmbeyiH99rmGcIhIaFPqHcPXxWUzKTObeOSsp3KmTtkSk+1PoH0J4mPHIReNwznHr7GXU1ambR0S6N4V+C9KTY/n5uUewMG8Hz328MdjliIi0iUI/ABdlD2TqqD48NG8t64p2B7scEZFWU+gHwMz49QVjSOgRwU0vLqWqRhdQF5HuSaEfoNT4Hjxw/hhWbS3j9wvWB7scEZFWUegfhjOO6MtFEwfy5Hu5LNm0M9jliIgcNoX+Ybrn3NH0S4zhlllLdcEVEel2FPqHKSE6kkcuHsemHRX88vXVwS5HROSwKPRbYXJWCtccn8kLi77k3bXFwS5HRCRgCv1WuuWMEYzok8BtL3/BzvKqYJcjIhIQhX4rRUeG89tLxrGrooq7/71Ck7KJSLeg0G+DI/onctPU4by+fCv/Wbol2OWIiLRIod9G3z8xiwmDkvjZf1awtXRvsMsRETkkhX4bRYSH8duLx1Nb5/jJ7C80KZuIdGkK/XaQkRrHXdNH8VFuCXf/ZwW1Cn4R6aIigl1AqLh80iAKduzlqfc3sG33Pn5/6VHERIUHuywRkQNoT7+dmBl3nDWSe88dzfzVRXzz2YUayikiXY5Cv51deVwmT14+gRVbyrjgqU8o2KErbolI16HQ7wBnjenHP64+hpLd+zj/j5+wYnNpsEsSEQEU+h1mUmYy//rBsUSGGZf86VM+XL8t2CWJiCj0O9KwPgm8cv1xpCfH8t3nF/PK54XBLklEPE6h38H6JkYz67opHJ2RzI9nLeOP723QlA0iEjQK/U7QMzqSv1x1NOeO68+Db67h3jkrNZZfRIJC4/Q7SY+IcB67ZDx9e/bgmQ83UlS2j99dOp7oSI3lF5HOE9CevplNM7O1ZpZrZnccot2FZubMLNt//5tmtrTBT52ZjW+v4rubsDDjrumj+dk5o5m36iu+/edF7KrQWH4R6Twthr6ZhQNPAGcBo4HLzGx0E+0SgBuBRfXLnHMvOOfGO+fGA98G8p1zS9ur+O7q6uMzefyyo1hWUMqFT33K5l1tm6htz74avijcxZsrtlJZXdtOVYpIKAqke2cSkOucywMwsxeBmcCqRu3uBx4Cbm3meS4D/l8r6ww554ztT2p8D773txzOf/Jjnr9yEqP792y2fV2dY0vpXjZsKydv2x42bNtD3rZyNmzbQ1HZvv3tvjNlMPfNPLIz/gQR6YYCCf0BQEGD+4XAMQ0bmNlRQLpz7jUzay70L8H3YSF+k7NSePm6Y7niuc+4+E+f8vS3JzIuPYmNJb4w/zrgy9lYsofK6rr96/aMjmBI73iOH5rGkN5xDEmL5901xfx94SZmjh/AxMG9gviXiUhXFUjoWxPL9g89MbMw4FHgymafwOwYoMI5t6KZx68FrgUYNGhQACWFjhF9E3jl+mO58vnP+OafF9FwNGeYwcBesQxJi+PYISkMSYtnSFocWWnxpMZHYXbgP81xQ1P5YN02fvrKcl678XgiwzU4S0QOFEjoFwLpDe4PBBpeJioBOBJ4zx9CfYE5ZjbDOZfjb3Mph+jacc49DTwNkJ2d7bmxjP2TYph93bE8/cEGoiPCGdI7niFp8QxOiT2s0T3xPSK4b+aRXPO3HJ7+II8fnjK0A6sWke4okNBfDAwzs0xgM74Av7z+QedcKZBaf9/M3gNurQ98/zeBi4AT26/s0JMYE8lPzhzZ5ueZOroPZ4/py2ML1nP2mH5kpsa1Q3UiEipa/P7vnKsBbgDmAauBWc65lWZ2n5nNCOA1TgQK6w8ES8e799wj6BERxl2vLtfZvyJyAOtqoZCdne1ycnJabiiH9MKiTdz16goevmgcF04cGOxyRKSDmdkS51x2S+10pC9EXXb0ILIH9+IXr69i+559La8gIp6g0A9RYWHGA+ePoXxfDb98fXWwy+lwxWWVTP/9h7z2xZaWG4t4mEI/hA3rk8APThrCK//bHNLz+TvnuPOV5azcUsa9c1ayZ19NsEsS6bIU+iHu+lOGkpUax12vrmBvVWhO0TB7SSEL1hRz4cSBlOyp4qn3NgS7JJEuS6Ef4qIjw/nV+WP4ckcFjy1YH+xy2l3hzgru++8qjslM5qELxnLe+P4882Fem+czEglVCn0PmJyVwiXZ6TzzYR6rtpQFu5x2U1fnuO3lL6hzjocvGkdYmPGTab5zHX7z5pogVyfSNSn0PeLOs0fSKzaSO19dHjIXcPnHok18smE7d08fTXpyLAADkmK45oRM/r10C8sKdgW5QpGuR6HvEUmxUfzsnNEsK9jF3z/ND3Y5bbaxpJwH5q7hxOFpXDYp/YDHfnDyUFLjo/jF66t0cppIIwp9D5kxrj8nDU/jN/PWsqUb93nX1jlunb2MyHDjoQvGHjTxXHyPCH58+ggW5+/kzRVfBalKka5Joe8hZsYvzjuSOgf3/GdFt90LfvbDPJZs2sn/zTyCvonRTba5OHsgI/ok8Os317CvJjRHLYm0hkLfY9KTY7n59GHMX13cLfeC1xXt5pG31nHmEX04b/yAZttFhIdx1/RRbNpewd8/3dSJFUqwrCvarcuPBkCh70FXHZfJ6H49+fmclZRVVge7nIBV19bx41lLiY+O4JffGHNQt05jJw5P46Thafx+wXp2lisMQllxWSUz/vARVz6/uNt+g+0sCn0PiggP49cXjKFkzz4e6kZDG594N5cVm8v41TeOJDW+R0Dr3DV9FHv21YTkOQrytT++v4HK6jqWFuxizjJNxXEoCn2PGjswie8el8k/Fn5JTv6OYJfTouWFpfzhnVzOG9+faUf2C3i94X0SuGzSIP6xcBMbtu3pwAqDyznHXz7eyLqi3cEupdMVlVXywqIvuWDCQI7o35MH31hDZbWO4zRHoe9hPz59OAOSYrjzleVU1dS1vEKQVFbXcsvspaTER/F/Mw7/ou83nz6c6MhwHpjbfb7VHK7ZSwq597+r+NGLS6kLkfMwAvXH9zZQV+f40WnDuHv6aLaUVvLnjzYGu6wuS6HvYXE9Irj/vCNYX7yHP73fdeereXT+OtYV7eHXF4wlMTbysNdPje/B9acMYf7qIj7ZUNIBFQbX5l17ue+/q+id0IPVW8v4r4dmGv2qtJJ/fubbyx+UEsuUISmcMboPT76bS/HuymCX1yUp9D3u1JF9mD62H4+/m0teF+z+WLJpB09/kMdlk9I5ZUTvVj/PVcdlMiAphl+8tjpkzkgG31QUt/unoph93RRG9+vJw2+t7dLf3NrTk+/lUlfnuOHUr68HfefZo6iqreOReeuCWFnXpdAXfn7uaHpEhPHTLnZ5xYqqGm6ZtYwBSTHcNX10m54rOjKc26aNYNXWMl75vLCdKgy+FxZt4qPcEu6aPorBKXHcftZICnbs5Z+LQn+Y6pZde3nxswIuyk7fPw0HQGZqHN+ZksGsJQUhNddUewnkwugS4nonRPPTs0dx5yvLGfLTucRGRRATFU5sVDixURH+3+HERPp/N1hW/3hMVDjxPSKYlJkc8Mialjz4xhryt1fw4rWTie/R9rfqjHH9ef7jfB5+ay3Tx/YjNqp7v/03bS/nV3PXcMKwVC6fNAiAE4elMiUrhcffyeXC7PR22W5d1RPv5uI4cC+/3o2nDuOVzwv5xeureOGaY1oc3uslofuOkMNySXY6EWHGpu0VVFTVsre6hoqqWv9PDXv21bBt9779y/ZW1VBRXUvjLwbhYcbJw9O4YOJAThvVmx4R4a2q5+PcEv766SauOi6TyVkp7fAX+s5I/tk5o7jgj5/y9Ad53DR1eLs8bzDUT0UREW48dOHXU1GYGXecNZKZT3zMMx/kcfPp3fdvPJTCnRXMying4ux0BiTFHPR4YmwkN00dzs/nrGT+6mJOH90nCFV2TQp9AXyXV7woO73lhg0459hXU0f5Pt8HxI7yKuau2Mq//7eZBWuKSYyJ5Jyx/bhg4kCOSk8KeG+rrLKa217+gqy0OG6bNqI1f06zJg5OZvqYfvzp/TwumzSIPj2bnsahq3v+440szt/JIxeNo1/igaE3Lj2Js8f05dkP8/jW5MGkJbTPN6+u5Il3N2AYPzzl4L38epcfM4i/fZrPr+au5qThaURFqDcb1KcvbWBmREeGkxLfg/TkWMalJ3HnWaP45I7T+OtVkzhpeBovLynk/Cc/4bTfvs8T7+YGNNHbL15bxdbSvTxy0TiiI1v3TeFQbp82kto6x8Pz1rb7c3eG3OLdPDRvLVNH9eH8CU1PRXHrGSOorKnjD++E3klpBTsqmJ1TwCVHp9O/ib38epH+qTg2lpTzj4Whf4wjUAp9aXfhYeab/uCyo8i5eyoPXjCG1Lge/GbeWo578B0uf2Yh/1pSSHkT17JdsLqIWTmF/ODkIRw1qFeH1DcoJZbvHpfBy58XsmJzaYe8Rkepqa3jllnLiIsK51fnH9nst6estHguOTqdf372JV9ur+jkKjvWE+/mEhZmXH/KkBbbnjKiNycMS+WxBes1L4+fQl86VEJ0JJccPYhZ103hg5+cwo9OG0bhzr3cMnsZR/9yPrfMWsYnuSXU1Tl2lldxxyvLGdk3gRtPG9ahdV1/ylB6xUbxy9dXd6kRSy156v0NLCss5f7zjqR3wqG7pn502jDCw4xH3u6e32ia8uX2Cl5eUsjlkwYd1K3VFDPj7umj2V1Zze/mh963ntZQ6EunGZQSy01Th/P+T05m1vencO7Y/sxb+RWXP7uIEx56lyuf/4xdFVX89uLxrT4AHKjEmEhumjqMT/O2M391cYe+VntZtaWMxxas55yx/ThnbP8W2/fpGc3Vx2fyn6Vbut03muY8/s56wsOMH5zc8l5+vRF9E7jUA1NxBEqhL53OzJiUmcyDF45l8V1TeezS8QzpHc/yzaX8+PQRjO7fs1PquGzSIIakxfHA3NVU13btk5mqanwzjCbGRHH/zMCnovj+SUNIio3koW56/KKh/JJyXvnfZi4/5vAPwN88tX4qjtUdVF33odCXoIqJCmfm+AH87apJrPy/aYe1B9dW9Qf68krKeaGLH+h7/J31rPlqNw+cP4ZecVEBr9czOpIfnjyUD9Zt45Pc7j0FxePv5BIRZvzgpMN/j6Ql9OCHpwxl/upiPu7m26GtFPrSZcREdWyXTlNOGdGb44am8LsF6ymt6JrXFlhWsIsn39vABRMGtmq8+benDKZ/YjQPvrmmWx2/aGhjSTmv/q+Qb00eTO9WDrP97nEZDOwVw/2vrQqpqTgOl0JfPM3MuOvs0ZTureYP73a9A32+GUaX0TuhB/ec27qpKKIjw7n59OEsKyzljW54tTSAxxesJyoijOtasZdfLzoynDvOGsmar3YzO6egHavrXhT64nmj+/fk4onp/OWTfF77Ygt52/ZQ00X6+B95ay25xXt48IKxJMYc/gyj9c6fMJDhfeJ5eN7aLn/8orEN2/bw76Wb+c6UjDafaDZ9TD+yB/fi4bfWsaeJIcNeoDNyRYBbzhjO/NVF3PDP/wEQFR5GZmocQ3vHM6R3PEN7xzOsdzyZqXEdcsJYUz7buINnP9rIN48ZxInD09r0XOFhxm1njuSav+UwK6eAbx4zuJ2q7Hi/X7CeHhHhXHtiVpufyzcVx2hmPvExT76by23TRrZDhd2LQl8E6N0zmg9vP4V1RXvILd7D+uLdbCjew8otpbyxYiv1XcBh5ru4/NC0eIb2iff99n8oJES3fk+8sfJ9Ndw6exkDe8Xw07NHtctznjaqN9mDe/HY/PV846gB3WLCudzi3cxZtoVrT8xqt4n8xqUn8Y2jBvDsRxu5/JhBDOwV2/JKISSgf3UzmwY8BoQDzzrnft1MuwuB2cDRzrkc/7KxwJ+AnkCd/zFd3UC6nNioCManJzE+PemA5ZXVtWwsKSe3eM8BPx+uL6GqQVdJ357RHJOVzNRRfThpRBo92/Ah8Os31lCws4IXvzeZuHaaKbN+MrYLn/qU5z/OP+S8NV3FYwtyiYkM5/sntu+orp+cOYI3VmzlwTfX8vhlR7Xrc3d1Lb6bzCwceAI4HSgEFpvZHOfcqkbtEoAbgUUNlkUA/wC+7ZxbZmYpQNccIiHSjOjIcEb168mofgeeP1BTW0fBzr2sL9pN7rY9rP1qNx+tL+E/S7cQEWb7PwCmjupzwHzvLflofQl/X7iJq4/P5Jh2mmG0XnaGr6an3tvA5ZMGHdbwz862rmg3r32xhetOGkJyO9fZPymGa0/I4vfv5HLlsRlMHNwxU350RdbSEC4zmwLc65w703//TgDn3AON2v0OmA/cCtzqnMsxs7OBy51z3wq0oOzsbJeTk3N4f4VIF1Fb51hasJO3VxUzf3URucW+M0BH9k3wfQCM7sPYAYmEhTU9Z05ZZTXTHv2A6Khw5t54QoccP1hXtJtpv/uAq4/PbPPFaTrSD//5Oe+tKeaj20/tkA+n8n01nPLwe/RPiuGVHxzb7L9Jd2FmS5xz2S21C2T0zgCg4fimQv+yhi92FJDunHut0brDAWdm88zsczO7rZlirzWzHDPL2bZtWwAliXRN4WHGxMHJ3HHWSOb/+CTeu/Vk7p4+iqTYSP74/gbOe+JjjnlgAXf86wvmrypib1XtAev/4rVVfFVW2WEzjAIM75PA+RMG8tdPNrE5gFlPg2HNV2XMXb6VK4/L6LBvI3E9IvjJmSNYWrDLU9cVDqSzsKmPv/1fD8wsDHgUuLKZ5z8eOBqoABb4P40WHPBkzj0NPA2+Pf2AKhfpBjJS47jmhCyuOSGLXRVVvL9uG2+vKuL1L7by4uICoiPDOH5oKlNH9SEyPIxZOYVc34EzjNa7+fThzFm2hUffXsfDF43r0NdqjcfmrycuKoLvndD2ETuHcsGEgfzlk3wefGMNZx7Rt9NGZgVTIKFfCDS8usZAoOHHYgJwJPCef5rXvsAcM5vhX/d951wJgJnNBSYAB4S+iBckxUYxc/wAZo4fQFVNHZ9t3MH81UW8vapo/6R3OtRbAAAJqUlEQVRvI/sm8KOpHTvDKMCApBiumDKYP3+0ke+dkMWIvgkd/pqBWrWljDdWfMWNpw4lKbZjjzmEhfmGcF769EKe/TCPG07t+G0fbIF07ywGhplZpplFAZcCc+ofdM6VOudSnXMZzrkMYCEwwz96Zx4w1sxi/Qd1TwJWHfwSIt4SFRHG8cNSuXfGEXx0+ym8edMJ3HX2KJ761sQOn2G03vUnDyUuKoLfzFvTKa8XqMcWrCMhOoKrj+/Yvfx6k7NSOPOIPjz53gaKd4f+wMIWQ985VwPcgC/AVwOznHMrzew+/978odbdCfwW3wfHUuBz59zrbS9bJHSYGSP79uR7J2aRkRrXaa/bKy6K604ewvzVxSzO39Fpr3soK7eUMm9lEVcdl0libPud99CSO88aRXVtHWc/9iE/evF/vLT4Swp2hNbFZ+q1OHqns2n0jkjn2VtVy0m/eZdBybHMvm5KwNcxBijeXcnKzWUs31zK8s2lrNpSRnRkGBkpcWSk+n4yU+LISI2lf2JMQKNjvve3HBbmbeej209t07QTrfHBum386/NCPtmwnW279wGQnhzDsVmpHDs0hSlDUlq8cE0wBTp6p+ufkiciHSYmKpybpg7np68uZ/7q4mZn8Swqq2R5YSkrtpSywh/yRWW+YDSDzNQ4JgzuRVVNLfklFXyUW8K+mq9PXIuKCGNwcqzvwyAltsEHQhx9e0YTFmYsLyzl7VVF3Dx1eKcHPsCJw9M4cXgazjlyi/fwcW4Jn2zYzhsrtvKSf4K24X3iOXZIKlOGpDA5KyUodbaV9vRFPK6mto4zHv2A8DDjzZtOpHi3P+D94b5iS9n+PV8zGJIWz5gBiRzRvydjBiQyun/Pg6agqKtzFO2uZGNJOfklFeRvL/ffLmfTjgqqGnwgREeGMTg5jorqGkorqvnojlPbdDZze6utc6zcUsonG7bzcW4Ji/N3UFldR5jBkQMSOXZIKscOSSE7o1dQp7YIdE9foS8izF2+letf+JyE6Ah2V/pmnwwzGNo7niMHJHJk/0TGDExkdL+ebZ4Woq7OsbWskvySrz8I8reX8+WOCq44NqPLTwa3r6aWZQWlfJxbwqcbtvO/gp1U1zoiw42j0nsxOSuZyUNSmDCoV6cOAVXoi0jAnHP8fM5KKqpqObJ/T8YMTGRUv57dYlK2YKuoqmFx/k4+yS3h07ztrNhcSp3zdWlNGJTElCxfd9C49MQOHZml0BcRCYKyymoWb9zBpxu282nedlZtLcM5XzfWxMG9mJLlOyg8dmASkeHtd0kThb6ISBewq6KKzzbu4NO87Xy6YTtrvtoNQGxUONkZyUzOSmZKVgpjBiQS0YYPAYW+iEgXtKO8ikV521mY5/smsK7INylffI8ILpuU3upJ8DRkU0SkC0qOi+KsMf04a0w/AEr27GOh/0OgX2JMh7++Ql9EJIhS43twztj+nDO2f6e8ni6MLiLiIQp9EREPUeiLiHiIQl9ExEMU+iIiHqLQFxHxEIW+iIiHKPRFRDyky03DYGbbgE1teIpUoKSdyukIqq9tVF/bqL626cr1DXbOpbXUqMuFfluZWU4g808Ei+prG9XXNqqvbbp6fYFQ946IiIco9EVEPCQUQ//pYBfQAtXXNqqvbVRf23T1+loUcn36IiLSvFDc0xcRkWZ0y9A3s2lmttbMcs3sjiYe72FmL/kfX2RmGZ1YW7qZvWtmq81spZn9qIk2J5tZqZkt9f/c01n1Nagh38yW+1//oEuVmc/v/dvwCzOb0El1jWiwXZaaWZmZ3dSoTadvPzN7zsyKzWxFg2XJZva2ma33/+7VzLpX+NusN7MrOrG+35jZGv+/36tmltTMuod8L3Rgffea2eYG/45nN7PuIf+/d2B9LzWoLd/Mljazbodvv3blnOtWP0A4sAHIAqKAZcDoRm2uB57y374UeKkT6+sHTPDfTgDWNVHfycBrQd6O+UDqIR4/G3gDMGAysChI/9Zf4Rt/HNTtB5wITABWNFj2EHCH//YdwINNrJcM5Pl/9/Lf7tVJ9Z0BRPhvP9hUfYG8FzqwvnuBWwN4Dxzy/3tH1dfo8UeAe4K1/drzpzvu6U8Ccp1zec65KuBFYGajNjOBv/pvvwycZmbWGcU557Y65z73394NrAYGdMZrt7OZwN+cz0Igycz6dXINpwEbnHNtOVmvXTjnPgB2NFrc8H32V+C8JlY9E3jbObfDObcTeBuY1hn1Oefecs7V+O8uBAa29+sGqpntF4hA/r+32aHq82fHxcD/a+/XDYbuGPoDgIIG9ws5OFT3t/G/6UuBlE6prgF/t9JRwKImHp5iZsvM7A0zO6JTC/NxwFtmtsTMrm3i8UC2c0e7lOb/owV7+wH0cc5tBd+HPdC7iTZdYTsCXIXvm1tTWnovdKQb/N1PzzXTPdYVtt8JQJFzbn0zjwdz+x227hj6Te2xNx6CFEibDmVm8cC/gJucc2WNHv4cX5fFOOBx4N+dWZvfcc65CcBZwA/N7MRGjwd1G5pZFDADmN3Ew11h+wWqK7wX7wJqgBeaadLSe6Gj/BEYAowHtuLrQmks6NsPuIxD7+UHa/u1SncM/UIgvcH9gcCW5tqYWQSQSOu+WraKmUXiC/wXnHOvNH7cOVfmnNvjvz0XiDSz1M6qz/+6W/y/i4FX8X2NbiiQ7dyRzgI+d84VNX6gK2w/v6L6Li//7+Im2gR1O/oPHJ8DfNP5O6AbC+C90CGcc0XOuVrnXB3wTDOvG+ztFwGcD7zUXJtgbb/W6o6hvxgYZmaZ/r3BS4E5jdrMAepHSVwIvNPcG769+fv//gysds79tpk2feuPMZjZJHz/Dts7oz7/a8aZWUL9bXwH/FY0ajYH+I5/FM9koLS+K6OTNLt3Fezt10DD99kVwH+aaDMPOMPMevm7L87wL+twZjYNuB2Y4ZyraKZNIO+Fjqqv4TGibzTzuoH8f+9IU4E1zrnCph4M5vZrtWAfSW7ND76RJevwHdW/y7/sPnxvboBofN0CucBnQFYn1nY8vq+fXwBL/T9nA9cB1/nb3ACsxDcSYSFwbCdvvyz/ay/z11G/DRvWaMAT/m28HMjuxPpi8YV4YoNlQd1++D6AtgLV+PY+r8Z3nGgBsN7/O9nfNht4tsG6V/nfi7nAdzuxvlx8/eH178P6EW39gbmHei90Un1/97+3vsAX5P0a1+e/f9D/986oz7/8L/XvuwZtO337teePzsgVEfGQ7ti9IyIiraTQFxHxEIW+iIiHKPRFRDxEoS8i4iEKfRERD1Hoi4h4iEJfRMRD/j85khvK7MEi4AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x21270fb5d68>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def train_rnn(num_epochs, num_steps, state_size=4, verbose=True):\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        #sess = tf_debug.LocalCLIDebugWrapperSession(sess)\n",
    "        training_losses = []\n",
    "        for idx, epoch in enumerate(gen_epochs(num_epochs, num_steps)):\n",
    "            training_loss = 0\n",
    "            training_state = np.zeros((batch_size, state_size))   # ->(200, 4)\n",
    "            if verbose:\n",
    "                print('\\nepoch', idx)\n",
    "            for step, (X, Y) in enumerate(epoch):\n",
    "                tr_losses, training_loss_, training_state, _ = \\\n",
    "                    sess.run([losses, total_loss, final_state, train_step], feed_dict={x:X, y:Y, init_state:training_state})\n",
    "                training_loss += training_loss_\n",
    "                if step % 100 == 0 and step > 0:\n",
    "                    if verbose:\n",
    "                        print('第 {0} 步的平均损失 {1}'.format(step, training_loss/100))\n",
    "                    training_losses.append(training_loss/100)\n",
    "                    training_loss = 0\n",
    "    return training_losses\n",
    "\n",
    "training_losses = train_rnn(num_epochs=5, num_steps=num_steps, state_size=state_size)\n",
    "print(training_losses[0])\n",
    "plt.plot(training_losses)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
